from __future__ import absolute_import, division, print_function

import os
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from random import random
from datetime import datetime

from functions import PIDController, Generator, calculate_t, MLP
from classic_auction import alpha_VCG_Mechanism, alpha_GSP_Mechanism
from models import Args, Score_VCG, Hyper_VCG

# Module-wide compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

class Learner:
    """Trainer for the hyper-networks that feed the ``Hyper_VCG`` auction model.

    Two MLPs, ``w_net`` and ``b_net``, each map the 4-element vector
    ``[low, high, alpha, beta]`` to 4 outputs. Those outputs are passed to
    the auction model as its last two positional arguments. Every MLP has
    its own Adam optimizer and a ``StepLR`` scheduler (step size 1,
    gamma 0.9999).
    """

    def __init__(self, args):
        """Create the auction model, both hyper-networks and their optimizers.

        Args:
            args: Configuration object handed to ``Hyper_VCG`` here and to
                ``Generator`` on every ``update_auction`` call.
        """
        self.args = args
        self.auct_model = Hyper_VCG(args)

        generator_vector = 4
        hidden_width = generator_vector * 10

        def build_net():
            # A fresh layer list per network, exactly one MLP per call.
            layer_sizes = [generator_vector, hidden_width, hidden_width, 4]
            return MLP(layer_sizes, t.tanh).to(device)

        def build_optim(net):
            optimizer = t.optim.Adam(net.parameters(), lr=4e-4, betas=(0.9, 0.999))
            scheduler = t.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9999, last_epoch=-1)
            return optimizer, scheduler

        # Networks are created first (w, then b) so parameter initialisation
        # happens in the same order as before.
        self.w_net = build_net()
        self.b_net = build_net()

        self.optimizer_w, self.scheduler_w = build_optim(self.w_net)
        self.optimizer_b, self.scheduler_b = build_optim(self.b_net)

    def update_auction(self):
        """Run one gradient step on ``w_net`` and ``b_net``.

        A random setup is drawn (``high`` in [0.5, 2.5), ``alpha`` in
        [0, 1.5), ``beta`` in [0, 8), ``low`` fixed at 0), fresh data is
        generated, and the negated mean of the auction model's first output
        is minimised. ``train_linear`` unpacks that first output as
        ``revenue``.
        """
        low = 0.
        high = 2. * np.random.uniform() + 0.5
        alpha = 1.5 * np.random.uniform()
        beta = 8 * np.random.uniform()
        distribution = t.tensor([low, high, alpha, beta]).to(device)

        generator = Generator(self.args)
        ctr_ads = generator.generate_uniform(0, 1.)
        ctr_og = generator.generate_uniform(0, 2.5)
        cvr_ads = generator.generate_uniform(0, 0.17)
        cvr_og = generator.generate_uniform(0, 0.17)
        train_data = generator.generate_uniform(low, high)

        hyper_w = self.w_net(distribution)
        hyper_b = self.b_net(distribution)
        outputs = self.auct_model(
            train_data, ctr_ads, ctr_og, cvr_ads, cvr_og,
            alpha, beta, hyper_w, hyper_b,
        )
        loss = -t.mean(outputs[0])

        optimizers = (self.optimizer_w, self.optimizer_b)
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        for scheduler in (self.scheduler_w, self.scheduler_b):
            scheduler.step()

def filter_pairs(x, y):
    """Keep only the non-negative, non-dominated points of paired coordinates.

    A point ``(x[i], y[i])`` is dropped when either coordinate is negative,
    or when some other point ``j`` is strictly larger in both coordinates
    (``x[j] > x[i]`` and ``y[j] > y[i]``). Ties in either coordinate do not
    count as domination, so points that are equal in one coordinate are
    both kept. The surviving points are returned in their original order.

    Args:
        x: 1-D array-like of first coordinates.
        y: 1-D array-like of second coordinates, same length as ``x``.

    Returns:
        Tuple ``(x_kept, y_kept)`` of NumPy arrays.

    Raises:
        ValueError: If ``x`` and ``y`` are not 1-D sequences of equal length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError(
            "x and y must be 1-D sequences of equal length, got shapes "
            f"{x.shape} and {y.shape}"
        )

    has_negative = (x < 0) | (y < 0)
    # Entry [i, j] is True when point j strictly exceeds point i in both
    # coordinates. The diagonal is always False, so no i != j mask is needed.
    beats = (x[np.newaxis, :] > x[:, np.newaxis]) & (y[np.newaxis, :] > y[:, np.newaxis])
    dominated = beats.any(axis=1)

    to_keep = ~(has_negative | dominated)
    return x[to_keep], y[to_keep]

def train_linear(args, args2, nets, rollouts, learner):
    """Run the joint training loop for the learner and the ten score nets.

    Rollouts ``i < 100`` are a warm-up: the learner gets one update, all ten
    entries of ``nets`` are trained, and no metrics are recorded. From
    rollout 100 on, every iteration

    * performs ten ``learner.update_auction()`` steps,
    * evaluates ``learner.auct_model`` twice on the same sampled data, once
      with the hyper-network input ``high = calculate_t(i)`` (histories
      ``hyper_*1``) and once with ``high = 1.5`` (histories ``hyper_*2``),
    * rescales ``alpha``/``beta`` (and ``alpha2``/``beta2``) by
      ``exp(PID.update(...))``,
    * trains and evaluates ``nets[0..4]`` (values drawn with upper bound 1.5)
      and ``nets[5..9]`` (values drawn with upper bound ``calculate_t(i)``).

    Plots are produced every 100/200 rollouts (starting at 200) and models
    are saved every 1000 rollouts (including rollout 0).

    Args:
        args: Configuration used for ``generator1`` and passed to every
            ``seller_backward`` call.
        args2: Configuration used for ``generator2``.
        nets: Indexable collection holding at least ten score models.
        rollouts: Number of loop iterations to run.
        learner: Object providing ``update_auction()``, ``auct_model``,
            ``w_net`` and ``b_net``.

    Returns:
        None. Results are only printed, plotted and checkpointed.
    """
    # Metric histories for the two learner evaluations. Each list starts with
    # a placeholder 0, which is part of the plotted/averaged tail until more
    # than 48 real samples have been appended.
    # NOTE(review): these lists grow for the whole run although the plotting
    # functions in this file only read the last 48 entries.
    hyper_losspr1, hyper_losscost1, hyper_lossclick1, hyper_losscvr1 = [0], [0], [0], [0]
    hyper_losspr2, hyper_losscost2, hyper_lossclick2, hyper_losscvr2 = [0], [0], [0], [0]
    hyper_perc1, hyper_perc2 = [0], [0]
    
    # One history list per net (ten per metric), again seeded with 0.
    score_losspr1, score_losscost1, score_lossclick1, score_losscvr1 = [[0],[0],[0],[0],[0],[0],[0],[0],[0],[0]], [[0],[0],[0],[0],[0],[0],[0],[0],[0],[0]], [[0],[0],[0],[0],[0],[0],[0],[0],[0],[0]], [[0],[0],[0],[0],[0],[0],[0],[0],[0],[0]]


    # generator1 feeds the score nets, generator2 feeds the learner evaluation.
    generator1 = Generator(args)
    generator2 = Generator(args2)
    
    # Starting values; (alpha, beta) belong to the calculate_t(i) evaluation,
    # (alpha2, beta2) to the high = 1.5 evaluation.
    alpha, alpha2 = 0.5, 0.5
    beta, beta2 = 0.01, 0.01
    
    # PID1/PID2 are driven by `perc`, PID12/PID22 by `cvr` (see the update
    # calls below). The meaning of the four constructor arguments is defined
    # in functions.PIDController -- presumably gains plus a target; confirm.
    PID1 = PIDController(0.05, 0.002, 1., 0.5)
    PID12 = PIDController(0.5, 0.02, 1., 0.1)
    PID2 = PIDController(0.05, 0.002, 1., 0.5)
    PID22 = PIDController(0.5, 0.02, 1, 0.1)

    for i in range(rollouts):
        if i < 100:
            # ---- Warm-up: train everything, record nothing. ----
            learner.update_auction()
            low, high = 0., 0.5 + 0.5 * 1  # high evaluates to 1.0
            # NOTE(review): the hyper-network input uses the constants 0.5 and
            # 0.1 while the model call below receives alpha (0.5) and beta
            # (0.01) -- confirm the 0.1 vs 0.01 difference is intended.
            distribution = t.tensor([low, high, 0.5, 0.1]).float().to(device)
            train_data2 = generator2.generate_uniform(low, high)
            ctr_ads2, ctr_og2 = generator2.generate_uniform(0, 1.), generator2.generate_uniform(0, 2.5)
            cvr_ads2, cvr_og2 = generator2.generate_uniform(0, 0.17), generator2.generate_uniform(0, 0.17)
            # Evaluation only: the result is printed and otherwise unused.
            revenue, cost, click, perc, cvr = learner.auct_model(train_data2, ctr_ads2, ctr_og2, cvr_ads2, cvr_og2, alpha, beta, learner.w_net(distribution), learner.b_net(distribution))
            rev = revenue.cpu().detach().numpy()
            print(rev)

            # nets[0..4]: values drawn from [0, 1.5], one fresh batch per net.
            for l in range(5):
                train_data = generator1.generate_uniform(0, 1.5)
                ctr_ads, ctr_og = generator1.generate_uniform(0, 1.), generator1.generate_uniform(0, 2.5)
                cvr_ads, cvr_og = generator1.generate_uniform(0, 0.17), generator1.generate_uniform(0, 0.17)
                nets[l].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og, [0.01, 0.1, 0.2, 0.5, 1][l], 2)

            # nets[5..9]: values drawn from [0, calculate_t(i)], one fresh
            # batch per net (unlike the post-warm-up branch, which shares one).
            for l in range(5):
                high = calculate_t(i)
                train_data = generator1.generate_uniform(0, high)
                ctr_ads, ctr_og = generator1.generate_uniform(0, 1.), generator1.generate_uniform(0, 2.5)
                cvr_ads, cvr_og = generator1.generate_uniform(0, 0.17), generator1.generate_uniform(0, 0.17)
                nets[l + 5].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og, [0.01, 0.1, 0.2, 0.5, 1][l], 2)
        else:
            # ---- Main phase: train, evaluate and record. ----
            for _ in range(10):
                learner.update_auction()

            # First evaluation: hyper-network input uses high = calculate_t(i).
            low, high = 0., calculate_t(i)
            distribution = t.tensor([low, high, alpha, beta]).float().to(device)
            # NOTE(review): this is the only call to generate_uniform2; every
            # other sample in this file uses generate_uniform -- confirm the
            # method exists and the difference is intended.
            train_data2 = generator2.generate_uniform2(low, high)
            ctr_ads2, ctr_og2 = generator2.generate_uniform(0, 1.), generator2.generate_uniform(0, 2.5)
            cvr_ads2, cvr_og2 = generator2.generate_uniform(0, 0.17), generator2.generate_uniform(0, 0.17)
            revenue, cost, click, perc, cvr = learner.auct_model(train_data2, ctr_ads2, ctr_og2, cvr_ads2, cvr_og2, alpha, beta, learner.w_net(distribution), learner.b_net(distribution))
            rev, cost, click, cvr = revenue.cpu().detach().numpy(), cost.cpu().detach().numpy(), click.cpu().detach().numpy(), cvr.cpu().detach().numpy()
            hyper_losspr1.append(rev)
            hyper_losscost1.append(cost)
            hyper_lossclick1.append(click)
            hyper_losscvr1.append(cvr)
            # NOTE(review): perc is stored as returned, without the
            # .cpu().detach().numpy() conversion applied to the other
            # metrics -- verify its type is safe to keep and to plot.
            hyper_perc1.append(perc)
            # Multiplicative update keeps alpha/beta positive.
            alpha *= np.exp(PID1.update(perc))
            beta *= np.exp(PID12.update(cvr))

            # Second evaluation: only the hyper-network input changes
            # (high = 1.5, alpha2/beta2). train_data2 and the ctr/cvr samples
            # from the first evaluation are reused, not re-sampled.
            high = 1.5
            distribution = t.tensor([low, high, alpha2, beta2]).float().to(device)
            revenue, cost, click, perc, cvr = learner.auct_model(train_data2, ctr_ads2, ctr_og2, cvr_ads2, cvr_og2, alpha2, beta2, learner.w_net(distribution), learner.b_net(distribution))
            rev, cost, click, cvr = revenue.cpu().detach().numpy(), cost.cpu().detach().numpy(), click.cpu().detach().numpy(), cvr.cpu().detach().numpy()
            hyper_losspr2.append(rev)
            hyper_losscost2.append(cost)
            hyper_lossclick2.append(click)
            hyper_losscvr2.append(cvr)
            hyper_perc2.append(perc)
            alpha2 *= np.exp(PID2.update(perc))
            beta2 *= np.exp(PID22.update(cvr))

            # nets[0..4]: train on a fresh [0, 1.5] batch, then evaluate on
            # the learner's evaluation data (train_data2 and friends).
            for l in range(5):
                train_data = generator1.generate_uniform(0, 1.5)
                ctr_ads, ctr_og = generator1.generate_uniform(0, 1.), generator1.generate_uniform(0, 2.5)
                cvr_ads, cvr_og = generator1.generate_uniform(0, 0.17), generator1.generate_uniform(0, 0.17)
                # `i % 1 == 0` is always true, so this guard never skips.
                if i % 1 == 0:
                    nets[l].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og, [0.01, 0.1, 0.2, 0.5, 1][l], 2)
                    # NOTE(review): unpack order here is (revenue, click,
                    # cost, cvr) whereas auct_model is unpacked as (revenue,
                    # cost, click, perc, cvr) -- confirm against Score_VCG.
                    revenue, click, cost, cvr = nets[l](train_data2, ctr_ads2, ctr_og2, cvr_ads2, cvr_og2, [0.01, 0.1, 0.2, 0.5, 1][l], 2)
                    score_losspr1[l].append(revenue.cpu().detach().numpy())
                    score_losscost1[l].append(cost.cpu().detach().numpy())
                    score_lossclick1[l].append(click.cpu().detach().numpy())
                    score_losscvr1[l].append(cvr.cpu().detach().numpy())

            # nets[5..9]: one shared value batch drawn from
            # [0, calculate_t(i)]; ctr/cvr are re-sampled per net.
            high = calculate_t(i)
            train_data = generator1.generate_uniform(0, high)
            for l in range(5):
                ctr_ads, ctr_og = generator1.generate_uniform(0, 1.), generator1.generate_uniform(0, 2.5)
                # NOTE(review): cvr upper bound is 0.15 here but 0.17
                # everywhere else in this function -- confirm intended.
                cvr_ads, cvr_og = generator1.generate_uniform(0, 0.15), generator1.generate_uniform(0, 0.15)
                # Always-true guard, as above.
                if i % 1 == 0:
                    nets[l + 5].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og, [0.01, 0.1, 0.2, 0.5, 1][l], 2)
                    revenue, click, cost, cvr = nets[l + 5](train_data2, ctr_ads2, ctr_og2, cvr_ads2, cvr_og2, [0.01, 0.1, 0.2, 0.5, 1][l], 2)
                    score_losspr1[l+5].append(revenue.cpu().detach().numpy())
                    score_losscost1[l+5].append(cost.cpu().detach().numpy())
                    score_lossclick1[l+5].append(click.cpu().detach().numpy())
                    score_losscvr1[l+5].append(cvr.cpu().detach().numpy())

        # Progress output and periodic reporting.
        if i % 2 == 0:
            print('i=', i)
        # `i > 100` makes 200 the first rollout that plots.
        if i % 100 == 0 and i > 100:
            plot_results(hyper_losspr1, hyper_perc1, 'revenue', 'percentage')
        if i % 200 == 0 and i > 100:
            plot_experiment_results(generator1, nets, hyper_losspr1, hyper_losscost1, hyper_lossclick1, hyper_losscvr1, hyper_losspr2, hyper_losscost2, hyper_lossclick2, hyper_losscvr2, hyper_perc1, hyper_perc2, i, score_losspr1, score_losscost1, score_lossclick1, score_losscvr1)
        # Checkpoints are written at rollout 0 as well.
        if i % 1000 == 0:
            save_models(learner, nets)

def plot_results(losspr, perc, label1, label2):
    """Plot the last 48 entries of two metric histories on one figure.

    Args:
        losspr: Sequence of values for the first curve.
        perc: Sequence of values for the second curve.
        label1: Legend label for ``losspr``.
        label2: Legend label for ``perc``.
    """
    # Use a dedicated figure so the curves never land on a figure left over
    # from an earlier plotting call.
    fig = plt.figure()
    plt.plot(losspr[-48:], label=label1)
    plt.plot(perc[-48:], label=label2)
    # Without a legend the labels passed above are never displayed.
    plt.legend()
    plt.show()
    # Release the figure; on non-interactive backends show() does not close
    # it, so figures would otherwise pile up over a long training run.
    plt.close(fig)

def plot_experiment_results(generator, nets, losspr2, losscost2, lossclick2, losscvr2, losspr4, losscost4, lossclick4, losscvr4, perc2, perc4, i, score_losspr1, score_losscost1, score_lossclick1, score_losscvr1):
    """Plot click-vs-cost curves for all mechanisms and print summary stats.

    For each of the five alpha values, the classic VCG and GSP mechanisms are
    evaluated on freshly generated data, and the trailing 48-sample means of
    the score-net histories are taken (lists 0-4 and lists 5-9). Each curve
    is passed through ``filter_pairs`` before plotting. The trailing means of
    the ``*2`` and ``*4`` histories are drawn as single points. The figure
    is saved via ``save_plot`` and then shown.

    Args:
        generator: Data generator providing ``generate_uniform(low, high)``.
        nets: Unused; kept for interface compatibility.
        losspr2: Unused; kept for interface compatibility.
        losscost2, lossclick2, losscvr2: Histories plotted/printed under the
            'AMMD (online)' labels.
        losspr4: Unused; kept for interface compatibility.
        losscost4, lossclick4, losscvr4: Histories plotted/printed under the
            'AMMD (offline)' labels.
        perc2, perc4: Histories summarised as mean/max/min in the printout.
        i: Unused; kept for interface compatibility.
        score_losspr1: Unused; kept for interface compatibility.
        score_losscost1, score_lossclick1, score_losscvr1: Ten per-net
            histories each; indices 0-4 are labelled 'SW-VCG (offline)',
            indices 5-9 'SW-VCG (online)'.
    """
    alphas = (0.01, 0.1, 0.2, 0.5, 1)
    recent = slice(-48, None)  # trailing window used for every average
    n_points = len(alphas)

    fixed_point = (np.mean(lossclick2[recent]), np.mean(losscost2[recent]))
    fixed_point2 = (np.mean(lossclick4[recent]), np.mean(losscost4[recent]))
    value_ads = generator.generate_uniform(0, 1.5)
    ctr_ads, ctr_og = generator.generate_uniform(0, 1.), generator.generate_uniform(0, 2.5)
    cvr_ads, cvr_og = generator.generate_uniform(0, 0.17), generator.generate_uniform(0, 0.17)

    # Size the arrays to the number of alpha values. The previous fixed size
    # of 8 left three never-written (0, 0) entries that showed up in the
    # printout and could survive filter_pairs when no positive point existed.
    x_point, y_point, z_point = np.zeros(n_points), np.zeros(n_points), np.zeros(n_points)
    x_point2, y_point2, z_point2 = np.zeros(n_points), np.zeros(n_points), np.zeros(n_points)
    x_point3, y_point3, z_point3 = np.zeros(n_points), np.zeros(n_points), np.zeros(n_points)
    x_point4, y_point4, z_point4 = np.zeros(n_points), np.zeros(n_points), np.zeros(n_points)

    for l, alpha_value in enumerate(alphas):
        x_point[l], y_point[l], z_point[l] = alpha_VCG_Mechanism(value_ads, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha_value, 2)
        x_point2[l], y_point2[l], z_point2[l] = alpha_GSP_Mechanism(value_ads, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha_value, 2)
        x_point3[l] = np.mean(score_lossclick1[l][recent])
        y_point3[l] = np.mean(score_losscost1[l][recent])
        z_point3[l] = np.mean(score_losscvr1[l][recent])
        x_point4[l] = np.mean(score_lossclick1[l + n_points][recent])
        y_point4[l] = np.mean(score_losscost1[l + n_points][recent])
        z_point4[l] = np.mean(score_losscvr1[l + n_points][recent])

    print('x_point:', x_point)
    print('y_point:', y_point)
    print('z_point:', z_point)
    print('x_point2:', x_point2)
    print('y_point2:', y_point2)
    print('z_point2:', z_point2)
    print('x_point3:', x_point3)
    print('y_point3:', y_point3)
    print('z_point3:', z_point3)
    print('x_point4:', x_point4)
    print('y_point4:', y_point4)
    print('z_point4:', z_point4)

    # Drop negative and strictly dominated points from each curve.
    x_point, y_point = filter_pairs(x_point, y_point)
    x_point2, y_point2 = filter_pairs(x_point2, y_point2)
    x_point3, y_point3 = filter_pairs(x_point3, y_point3)
    x_point4, y_point4 = filter_pairs(x_point4, y_point4)

    fig = plt.figure(dpi=600)
    plt.plot(x_point, y_point, marker='o', linestyle='-', color='green', label='VCG ')
    plt.plot(x_point2, y_point2, marker='s', linestyle='-', color='blue', label='GSP')
    plt.plot(x_point3, y_point3, marker='v', linestyle='-', color='gray', label='SW-VCG (offline)')
    plt.plot(x_point4, y_point4, marker='x', linestyle='-', color='violet', label='SW-VCG (online)')
    plt.scatter(*fixed_point2, color='red', marker='p', label='AMMD (offline)')
    plt.scatter(*fixed_point, color='brown', marker='d', label='AMMD (online)')
    plt.xlabel('click')
    plt.ylabel('cost')
    plt.legend()
    plt.title('Experiments in dynamic environments')

    # Save the plot as PDF
    save_plot()

    plt.show()
    # Release the 600-dpi figure; on non-interactive backends show() leaves
    # it open and repeated calls would accumulate figures in memory.
    plt.close(fig)

    print('AMMD (online):', fixed_point)
    print('AMMD (online CVR):', np.mean(losscvr2[recent]))
    print('AMMD (offline):', fixed_point2)
    print('AMMD (offline CVR):', np.mean(losscvr4[recent]))
    print('orgs_perc:', np.mean(perc2[recent]))
    print('orgs_perc_max:', np.max(perc2[recent]))
    print('orgs_perc_min:', np.min(perc2[recent]))
    print('orgs_percoff:', np.mean(perc4[recent]))
    print('orgs_perc_maxoff:', np.max(perc4[recent]))
    print('orgs_perc_minoff:', np.min(perc4[recent]))

def save_plot():
    """Save the current figure as a timestamped PDF under ``imgs/``.

    The file name has the form ``imgs/plot_YYYYmmdd_HHMMSS.pdf``, so two
    saves within the same second write to the same file.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('imgs', exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    plt.savefig(f'imgs/plot_{timestamp}.pdf')

def save_models(learner, nets):
    """Serialise the learner and every net into the ``checkpoints`` directory.

    The learner is written to ``checkpoints/learner.pth`` and the net at
    position ``idx`` to ``checkpoints/net_{idx}.pth``. Whole objects are
    saved with ``torch.save``; existing files are overwritten.

    Args:
        learner: Object to store as the learner checkpoint.
        nets: Iterable of objects to store, one file per element.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('checkpoints', exist_ok=True)
    t.save(learner, 'checkpoints/learner.pth')
    for idx, net in enumerate(nets):
        t.save(net, f'checkpoints/net_{idx}.pth')

def load_models(num_nets=10):
    """Load the learner and the nets from the ``checkpoints`` directory.

    Reads ``checkpoints/learner.pth`` and ``checkpoints/net_{idx}.pth`` for
    ``idx`` in ``range(num_nets)``, i.e. the files ``save_models`` writes.

    Args:
        num_nets: Number of net checkpoints to read. Defaults to 10.

    Returns:
        Tuple ``(learner, nets)`` where ``nets`` is a list of ``num_nets``
        loaded objects.
    """
    import inspect

    # save_models stores whole pickled objects, not state dicts. Newer
    # PyTorch releases default torch.load to weights_only=True, which refuses
    # such files, so full unpickling is requested explicitly when the
    # installed version knows that keyword.
    # SECURITY: full unpickling can execute arbitrary code; only load
    # checkpoints produced by this project.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(t.load).parameters:
        load_kwargs['weights_only'] = False

    learner = t.load('checkpoints/learner.pth', **load_kwargs)
    nets = [
        t.load(f'checkpoints/net_{idx}.pth', **load_kwargs)
        for idx in range(num_nets)
    ]
    return learner, nets

if __name__ == "__main__":
    # Switch to True to resume from the files under checkpoints/ instead of
    # starting with a freshly constructed learner.
    LOAD_FROM_CHECKPOINT = False

    # The two configurations differ only in their sixth and seventh entries
    # (100 vs 1000). args drives the score nets, args2 the learner.
    args = Args((4, 1, "uniform", 10, 10, 100, 100, 1))
    args2 = Args((4, 1, "uniform", 10, 10, 1000, 1000, 1))

    # Ten score nets are always constructed first; a checkpoint load below
    # replaces them.
    nets = []
    for _ in range(10):
        nets.append(Score_VCG(args))

    if not LOAD_FROM_CHECKPOINT:
        learner = Learner(args2)
    else:
        learner, nets = load_models()

    train_linear(args, args2, nets, rollouts=100001, learner=learner)
